General offset computation for dynamic slice fusion #16879
1ca6525
to
470a734
Compare
Ping for review!
// CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]])
// CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}})
)"));
ASSERT_TRUE(filecheck_fusion);
nit: use EXPECT_TRUE instead of ASSERT_TRUE.
If this check fails, it might still make sense to continue and run the next filecheck.
Done
// CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom
// CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}})
// CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]])
)"));
Missing EXPECT_TRUE(filecheck_while_loop);
Done
<< computation->name();
std::unique_ptr<HloComputation> cloned_computation =
const_cast<HloComputation*>(computation)->Clone("main", &context);
HloPrintOptions options;
These lines (up to line 205) are probably left from debugging? Please remove.
Done
bool GetOperationModule(const HloInstruction* op, HloModule* module,
const HloInstruction* indvar) {
HloCloneContext context(module);
CHECK(!module->has_entry_computation()) << "Expected empty module";
I believe you can reuse xla/tools/hlo_extractor.h for this. You can check on the original computation whether the offset is derived from indvar, and then call ExtractModule from hlo_extractor.h with op as parameter.
Thanks! I changed it.
// and `updated_i = g(i)`. If such a computation cannot be created within the
// module (maybe the offset depends on partition-id or something else), then it
// returns false. If it is successful, it returns true.
bool GetOperationModule(const HloInstruction* op, HloModule* module,
Can this method just return std::unique_ptr<HloModule> instead of bool? You can let it return nullptr in case it doesn't succeed.
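To illustrate the suggestion above, here is a minimal sketch of the two API shapes side by side. ToyModule, FillModule, and MakeModule are hypothetical names invented for this example; they are not part of XLA and are not the code from this PR.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for HloModule; not the real XLA class.
struct ToyModule {
  std::string name;
};

// Shape under review: the caller supplies the module and gets a success flag.
bool FillModule(bool offset_is_computable, ToyModule* module) {
  assert(module != nullptr);
  if (!offset_is_computable) return false;
  module->name = "offset";
  return true;
}

// Suggested shape: one return value carries both ownership and success.
// A null pointer means the module could not be built.
std::unique_ptr<ToyModule> MakeModule(bool offset_is_computable) {
  if (!offset_is_computable) return nullptr;
  return std::make_unique<ToyModule>(ToyModule{"offset"});
}
```

With the second shape the caller cannot accidentally use a half-filled module after a failure, because there is no module to use.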
Removed this function as suggested.
470a734
to
4291c45
Compare
4291c45
to
bcf8a34
Compare
}

return true;
// Extract the offset and update modules and verify that they only take the
It took me quite some time to understand why we don't need to check that loop_indvar is actually a direct or indirect operand of idx. This probably deserves a comment.
My reasoning was: we know that the computation from which we extract the module only has a single parameter (a tuple), and loop_indvar is a GetTupleElement user of that parameter. ExtractModule will create a parameter if it encounters loop_indvar and will not visit its operand. If it does not encounter loop_indvar, it will stop at the tuple parameter. Therefore the check in line 241 that the single parameter of the module has the expected shape (not the tuple shape) is sufficient.
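The reasoning above can be sketched on a toy graph. Everything here (Node, CollectParameters, DependsOnlyOnIndvar) is a hypothetical miniature written for illustration; it is not ExtractModule and makes no claim about how the real extractor is implemented. It only models the two stopping rules described: stop at the boundary node, otherwise stop at the tuple parameter.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical miniature of an HLO graph; none of these types are XLA's.
struct Node {
  std::string shape;                  // e.g. "s32[]" or "(s32[], f32[8])"
  std::vector<const Node*> operands;  // empty for parameters and constants
  bool is_parameter = false;
};

// Collects the nodes that would become parameters of an extracted module.
// The boundary node is turned into a parameter and its operands are skipped;
// otherwise the walk continues until it reaches an original parameter.
void CollectParameters(const Node* node, const Node* boundary,
                       std::set<const Node*>* params) {
  assert(node != nullptr && params != nullptr);
  if (node == boundary || node->is_parameter) {
    params->insert(node);
    return;
  }
  for (const Node* operand : node->operands) {
    CollectParameters(operand, boundary, params);
  }
}

// True if the extracted module would have exactly one parameter and that
// parameter has the induction variable's shape rather than the tuple shape.
bool DependsOnlyOnIndvar(const Node* offset, const Node* indvar) {
  std::set<const Node*> params;
  CollectParameters(offset, indvar, &params);
  return params.size() == 1 && (*params.begin())->shape == indvar->shape;
}
```

In this model, an offset that reads any other tuple element reaches the tuple parameter, so the single-parameter shape check fails without a separate reachability check on the induction variable.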
Done. And made the check stricter. Thanks for pointing this out.
});
// This function takes an array literal and gives a constant instruction with
// that literal.
std::unique_ptr<HloInstruction> GetAsConstantInstruction(
I think this can be simplified using stuff from primitive_util.h:
Line 496 in c4ae62b
constexpr R PrimitiveTypeSwitch(F&& f, PrimitiveType type) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
I couldn't use PrimitiveTypeToNative: it requires the template argument to be a constant, and we would like that argument to be offset_values[0].shape().element_type(). But I used NativeToPrimitiveType. Let me know if you have something nicer in mind. Thanks!
It only works in combination with PrimitiveTypeSwitch. This will forward the different possible PrimitiveType values as constants to your templated function. Sorry if that wasn't clear.
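As a rough illustration of the pattern described above, the sketch below shows how a switch can forward a runtime enum value to a templated callee as a compile-time constant. ToyType, ToyTypeToNative, ToyTypeSwitch, and NativeSizeInBytes are hypothetical names made up for this example; the real helpers live in primitive_util.h and may differ in signature and detail.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

// Toy enum standing in for a runtime element type (hypothetical, two values).
enum class ToyType { kS32, kS64 };

// Compile-time map from enum value to native type.
template <ToyType kType>
struct ToyTypeToNative;
template <>
struct ToyTypeToNative<ToyType::kS32> { using type = int32_t; };
template <>
struct ToyTypeToNative<ToyType::kS64> { using type = int64_t; };

// Bridges runtime to compile time: the switch picks a branch at runtime and
// passes the enum value wrapped in std::integral_constant, so the callee can
// use it as a template argument.
template <typename R, typename F>
R ToyTypeSwitch(F&& f, ToyType type) {
  switch (type) {
    case ToyType::kS32:
      return std::forward<F>(f)(
          std::integral_constant<ToyType, ToyType::kS32>{});
    case ToyType::kS64:
      return std::forward<F>(f)(
          std::integral_constant<ToyType, ToyType::kS64>{});
  }
  assert(false && "unhandled ToyType");
  return R{};
}

// Example callee: byte width of the native type for a runtime enum value.
inline int NativeSizeInBytes(ToyType type) {
  return ToyTypeSwitch<int>(
      [](auto constant) {
        using NativeT =
            typename ToyTypeToNative<decltype(constant)::value>::type;
        return static_cast<int>(sizeof(NativeT));
      },
      type);
}
```

The generic lambda is instantiated once per enum value, which is why the enum-to-native mapping can be used inside it even though the caller only has a runtime value.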
I think here is a pointer to existing code that should be helpful:
https://github.com/openxla/xla/blob/main/xla/service/hlo_parser.cc#L4207
Done. Thanks for the recommendation!
6c04cd2
to
2d5f03a
Compare
2d5f03a
to
6369888
Compare
This patch adds logic for computing general offset while creating dynamic-slice-fusion.
Imported from GitHub PR #16879 This patch adds logic for computing general offset while creating dynamic-slice-fusion. Copybara import of the project: -- 6369888 by Shraiysh Vaishay <[email protected]>: General offset computation for dynamic slice fusion This patch adds logic for computing general offset while creating dynamic-slice-fusion. Merging this change closes #16879 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16879 from shraiysh:general_offset_dynamic_slice_fusion 6369888 PiperOrigin-RevId: 676711783
Imported from GitHub PR openxla/xla#16879 This patch adds logic for computing general offset while creating dynamic-slice-fusion. Copybara import of the project: -- 636988837ad87f0bb0b4906ff5653e6982a5cad0 by Shraiysh Vaishay <[email protected]>: General offset computation for dynamic slice fusion This patch adds logic for computing general offset while creating dynamic-slice-fusion. Merging this change closes #16879 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16879 from shraiysh:general_offset_dynamic_slice_fusion 636988837ad87f0bb0b4906ff5653e6982a5cad0 PiperOrigin-RevId: 676711783
Imported from GitHub PR openxla/xla#16879 This patch adds logic for computing general offset while creating dynamic-slice-fusion. Copybara import of the project: -- 636988837ad87f0bb0b4906ff5653e6982a5cad0 by Shraiysh Vaishay <[email protected]>: General offset computation for dynamic slice fusion This patch adds logic for computing general offset while creating dynamic-slice-fusion. Merging this change closes #16879 PiperOrigin-RevId: 676813991
This PR was rolled back in aace801!